# -*- coding: utf-8 -*-
'''
Created on 2016年12月22日

@author: ZhuJiahui
'''

import numpy as np
from operator import itemgetter


def sparse_topic_vsm(topic_vsm, sp):
    '''
    Sparsify a distribution matrix: keep only the ``sp`` largest entries
    of each row and renormalize them so the row sums to 1.

    :param topic_vsm: dense distribution, 2d array-like (each row is a distribution)
    :param sp: sparsity level — number of entries to keep per row
    :return: sparsified distribution (numpy 2d array); returned unchanged
             when there are fewer than ``sp`` columns
    '''

    # Accept plain nested lists as well as ndarrays.
    topic_vsm = np.asarray(topic_vsm, dtype=float)
    row_num, col_num = topic_vsm.shape

    # Nothing to prune when the row is already shorter than the sparsity level.
    if col_num < sp:
        return topic_vsm

    sp_topic_vsm = np.zeros((row_num, col_num))
    for i in range(row_num):
        # Indices of the sp largest entries of row i.
        top_index = np.argsort(topic_vsm[i])[(col_num - sp):]
        top_values = topic_vsm[i][top_index]
        total = np.sum(top_values)
        # Guard against an all-zero row: leave it as zeros instead of NaN.
        if total > 0:
            sp_topic_vsm[i][top_index] = np.true_divide(top_values, total)
    return sp_topic_vsm


def get_real_topics(sp_topic_vsm, this_word_list):
    '''
    Render each topic as a readable "weight*word" string, strongest first.

    :param sp_topic_vsm: topic-word distribution (2d)
    :param this_word_list: vocabulary list, indexed like the columns
    :return: list of topic strings, one per topic row
    '''

    real_topics = []

    for row in sp_topic_vsm:
        # Keep only words with a non-negligible weight.
        pairs = [(this_word_list[idx], weight)
                 for idx, weight in enumerate(row)
                 if weight > 0.0001]
        # Strongest words first.
        pairs.sort(key=itemgetter(1), reverse=True)
        rendered = [str(weight) + '*' + word for word, weight in pairs]
        real_topics.append(" ".join(rendered))

    return real_topics


def get_topic_proportion(sp_topic_vsm):
    '''
    Compute how often each topic appears across all documents.

    :param sp_topic_vsm: document-topic distribution (2d numpy array)
    :return: (per-topic document counts, numpy 1d array;
              per-topic appearance frequency, numpy 1d array)
    '''

    doc_num = len(sp_topic_vsm)
    # ceil maps any positive weight to 1, giving a presence indicator.
    presence = np.ceil(sp_topic_vsm)
    topic_count = presence.sum(axis=0)  # column-wise sum
    topic_proportion = np.true_divide(topic_count, doc_num)

    return topic_count, topic_proportion


def get_topic_rfc(sp_topic_vsm, t_delta):
    '''
    Collect, per topic, the related feed collection (rfc): ids of the
    documents whose weight for that topic exceeds the threshold.

    :param sp_topic_vsm: document-topic distribution
    :param t_delta: relevance threshold (strictly greater-than)
    :return: 2d list — one list of document-id strings per topic
    '''

    doc_num = len(sp_topic_vsm)
    topic_num = len(sp_topic_vsm[0])

    return [
        [str(doc) for doc in range(doc_num) if sp_topic_vsm[doc][topic] > t_delta]
        for topic in range(topic_num)
    ]


def get_topic_sim_by_rfc(rfc1, rfc2):
    '''
    Jaccard similarity between two topics' related-feed collections.

    :param rfc1: collection for topic 1
    :param rfc2: collection for topic 2
    :return: similarity in [0, 1]; 0.0 when both collections are empty
    '''

    set1 = set(rfc1)
    set2 = set(rfc2)
    union_size = len(set1 | set2)

    # Both empty: define similarity as 0 to avoid dividing by zero.
    if union_size == 0:
        return 0.0

    return np.true_divide(len(set1 & set2), union_size)


def merge_rfc(rfc1, rfc2):
    '''
    Merge two topic-related document collections (set union, deduplicated).

    :param rfc1: collection 1
    :param rfc2: collection 2
    :return: merged collection as a list
    '''

    merged = set(rfc1)
    merged.update(rfc2)
    return list(merged)


def merge_topic_vsm(topic_vsm1, topic_vsm2):
    '''
    Stack two topic vector spaces vertically (feature dimension unchanged),
    i.e. new = [topic_vsm1; topic_vsm2].

    :param topic_vsm1: vector space 1 (1d or 2d numpy array)
    :param topic_vsm2: vector space 2 (1d or 2d numpy array)
    :return: merged topic vector space (2d float array); 1d inputs are
             treated as single rows
    '''

    # np.vstack already promotes 1d inputs to single rows, which covers all
    # four 1d/2d combinations the previous hand-rolled code enumerated.
    # Cast to float64 to keep the original np.zeros-based output dtype.
    return np.vstack((topic_vsm1, topic_vsm2)).astype(np.float64, copy=False)


def merge_center_topic_by_support(center_topic1, support1, center_topic2, support2):
    '''
    Merge two topic distributions, weighting each by its support.

    :param center_topic1: topic distribution 1 (1d)
    :param support1: support of distribution 1
    :param center_topic2: topic distribution 2 (1d)
    :param support2: support of distribution 2
    :return: merged, renormalized topic distribution (1d numpy array);
             all zeros when no component exceeds the selection threshold
    '''

    # Only merge positions where at least one side exceeds this threshold.
    minimum_selected = 0.00001

    total_support = support1 + support2
    pro1 = np.true_divide(support1, total_support)
    pro2 = np.true_divide(support2, total_support)

    new_center_topic = np.zeros(len(center_topic1))
    temp_sum = 0.0

    for i in range(len(center_topic1)):
        if (center_topic1[i] > minimum_selected) or (center_topic2[i] > minimum_selected):
            merged = pro1 * center_topic1[i] + pro2 * center_topic2[i]
            new_center_topic[i] = merged
            temp_sum += merged

    # Renormalize to a probability distribution; the guard avoids a
    # divide-by-zero (all-NaN result) when nothing was selected.
    if temp_sum > 0.0:
        new_center_topic = np.true_divide(new_center_topic, temp_sum)

    return new_center_topic

if __name__ == '__main__':
    # Quick smoke test: every weight strictly above 1 marks the
    # corresponding document as related to that topic.
    sample = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 11, 0, 10, 9]])
    related = get_topic_rfc(sample, 1)
    print(related)